from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM


class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
    max_seq_len_key = "max_seq_len"

    @staticmethod
    def fuse_layers(model: MptForCausalLM):
        fuser = MptFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks

    @staticmethod
    def get_act_for_scaling(module: OldMptBlock):
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )

    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
        layers = []

        if module_kwargs.get("output_attentions") is not None:
            module_kwargs.pop("output_attentions")

        # attention input
        layers.append(
            dict(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )

        # attention output
        layers.append(
            dict(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )

        # linear 1
        layers.append(
            dict(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

        return layers


from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel


class MptFuser:
    def __init__(self, model: MptForCausalLM):
        self.model = model

        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "mptblock" in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMptBlock
        for module in self.model.transformer.blocks:
            blocks.append(
                MPTBlock(
                    self.model.config.d_model,
                    self.model.config.n_heads,
                    module.attn.Wqkv,
                    module.attn.out_proj,
                    module.ffn,
                    module.norm_1,
                    module.norm_2,
                    next(iter(module.state_dict().values())).device,
                    self.model.config.max_seq_len,
                )
            )

        self.model.transformer = MPTModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.wte,
            self.model.transformer.norm_f,
        )

        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
